Adapted from https://github.com/Jackson-Kang/Pytorch-VAE-tutorial
Consider downloading this Jupyter Notebook and running it locally, or try it in Colab.
Install required packages
!pip install torch
!pip install matplotlib
# Import required packages
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from torchvision.utils import save_image, make_grid
# Model hyperparameters
dataset_path = '~/datasets'
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
batch_size = 100

# Note that VQ-VAE embeds spatial information, so it needs a CNN;
# the latent space also preserves spatial information.
input_dim = 1
hidden_dim = 32
latent_dim = 16
n_embeddings = 32  # the length of the codebook
output_dim = 1
commitment_beta = 0.25

lr = 2e-4
epochs = 2  # just for illustration...
print_step = 100
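To make the roles of `latent_dim`, `n_embeddings`, and `commitment_beta` concrete, here is a minimal sketch of the vector-quantization step that a VQ-VAE typically places between encoder and decoder. It is not necessarily the implementation used later in this notebook: the function name `quantize`, the random codebook, and the 7x7 latent grid are assumptions chosen only for illustration. Each spatial position of the encoder output is replaced by its nearest codebook vector, and the commitment term is weighted by `commitment_beta`.

```python
import torch
import torch.nn.functional as F

# Values copied from the hyperparameter cell.
latent_dim = 16
n_embeddings = 32
commitment_beta = 0.25


def quantize(z_e, codebook, beta=commitment_beta):
    """Hypothetical helper: snap each latent vector to its nearest codebook entry.

    z_e:      encoder output, shape (B, latent_dim, H, W)
    codebook: embedding table, shape (n_embeddings, latent_dim)
    """
    B, D, H, W = z_e.shape
    # Flatten the spatial grid so every position is one D-dimensional vector.
    flat = z_e.permute(0, 2, 3, 1).reshape(-1, D)     # (B*H*W, D)
    # Euclidean distance from every latent vector to every codebook entry.
    dist = torch.cdist(flat, codebook)                # (B*H*W, n_embeddings)
    indices = dist.argmin(dim=1)
    z_q = codebook[indices].view(B, H, W, D).permute(0, 3, 1, 2)

    # Codebook loss pulls the embeddings toward the encoder outputs;
    # commitment loss pulls the encoder outputs toward the embeddings.
    codebook_loss = F.mse_loss(z_q, z_e.detach())
    commitment_loss = F.mse_loss(z_e, z_q.detach())
    loss = codebook_loss + beta * commitment_loss

    # Straight-through estimator: the forward pass uses z_q,
    # the backward pass copies gradients directly to z_e.
    z_q_st = z_e + (z_q - z_e).detach()
    return z_q_st, loss, indices.view(B, H, W)


torch.manual_seed(0)
codebook = torch.randn(n_embeddings, latent_dim, requires_grad=True)
z_e = torch.randn(2, latent_dim, 7, 7, requires_grad=True)  # 7x7 grid is an assumption
z_q, vq_loss, indices = quantize(z_e, codebook)
print(z_q.shape, indices.shape)
```

The quantized tensor keeps the shape of the encoder output, so the spatial layout of the latent space is preserved, while `indices` holds one integer code per spatial position.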